#!/usr/bin/python2.4
# vim:ts=4:sw=4:softtabstop=4:smarttab:expandtab
# 
#    Copyright (C) 1999-2006  Keith Dart <keith@kdart.com>
#
#    This library is free software; you can redistribute it and/or
#    modify it under the terms of the GNU Lesser General Public
#    License as published by the Free Software Foundation; either
#    version 2.1 of the License, or (at your option) any later version.
#
#    This library is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#    Lesser General Public License for more details.

"""
Compiles SMI (MIB) data into Python source modules, defining one class per
SMI object, for use by the SNMP module.  This started out clean but has
grown ugly; at least it produces something useful.

"""


import os
import py_compile
from string import translate, maketrans

from pycopia.SMI import SMI, Basetypes, Objects

# Directory that _compile_module() writes generated <module>.py and
# <module>_OID.py files into.  Defaults to /var/tmp/mibs; override it with
# the USERMIBPATH environment variable.
USERMIBPATH = os.environ.get("USERMIBPATH", os.path.join("/", "var", "tmp", "mibs"))

# global name translation table
# Since we convert MIB modules to Python modules, we can't have a dash in
# the name. Dashes are translated to underscores.
TRANSTABLE = maketrans("-", "_")
def convert_name(name):
    """Return *name* with every dash replaced by an underscore.

    MIB module and object names may contain dashes, which are not legal in
    Python identifiers or module names.  Uses str.replace rather than the
    deprecated string.translate module function (removed in Python 3); the
    result is the same as translating through TRANSTABLE.
    """
    return name.replace("-", "_")

# These are some of the attributes that the SNMP module needs, and are
# exported "as-is". Other attributes are special-cased in the appropriate
# generator method.
# Keys are SMI object class names (genClass looks them up with
# sminode.__class__.__name__); values list the attribute names copied from
# that object into the generated class.
EXPORTS = {
    "Type": ["status", "format", "units", "ranges", "enumerations"],
    "Node": ["access", "create", "status", "units"],
    "Macro": ["name", "status"],
    "Module": ["name", "path", "conformance", "language", "description"],
    "Group": ["name", "status"],
    "Value": ["val"],
}

# objects directly imported from SMI.Objects in the mib modules.
# The generated module header imports these by bare name (in this order),
# and _classstr() therefore refers to them unqualified.
IMPORTED_OBJECTS = ["ColumnObject", "MacroObject", "NotificationObject", 
    "RowObject", "ScalarObject", "NodeObject", "ModuleObject", "GroupObject"]

def _classstr(tup):
    def _cstr(c):
        if type(c) is str:
            return c
        else:
            if c.__name__ in IMPORTED_OBJECTS:
                return c.__name__
            else:
                return "%s.%s" % (c.__module__, c.__name__)
    return ", ".join(map(_cstr, tup))

# generic class producer. Returns source code string
def genClass(sminode, baseclass, attrdict=None, doc=None):
    """Return Python source text for a class describing *sminode*.

    sminode   -- an SMI object.  Its class name selects the attribute list
                 from EXPORTS, and its ``name`` (dashes converted by
                 convert_name) becomes the generated class name.
    baseclass -- base class for the generated class: a class object or a
                 source-text string (see _classstr), or None to derive
                 from ``object``.
    attrdict  -- optional mapping of extra class attributes whose values
                 are already source text (e.g. built with repr()).  It is
                 copied, never modified.
    doc       -- optional docstring text for the generated class.

    Exported string attributes are quoted with repr(); attributes whose
    value is false are omitted.  Attributes are written in sorted order so
    the generated source is identical from run to run.  The returned text
    ends with a blank line.
    """
    attrs = {}
    if attrdict:
        attrs.update(attrdict)
    for attrname in EXPORTS[sminode.__class__.__name__]:
        val = getattr(sminode, attrname)
        if val is None:
            continue
        if type(val) is str:
            # plain strings must be quoted to become valid source
            attrs[attrname] = repr(val)
        else:
            attrs[attrname] = val
    klassname = convert_name(sminode.name)
    if baseclass is None:
        s = ["class %s(object):" % (klassname,)]
    else:
        s = ["class %s(%s):" % (klassname, _classstr((baseclass,)))]
    if doc:
        s.append('\t"""%s"""' % doc)
    for key in sorted(attrs.keys()):
        val = attrs[key]
        if val:
            s.append("\t%s = %s" % (key, val))
    if len(s) == 1:
        # no docstring and no attributes: the class body must not be empty
        s.append("\tpass")
    s.append("\n")
    return "\n".join(s)


# generates a repr for SMI.Objects.IndexObjects
class IndexGenerator(list):
    """List of index-object names whose repr() is constructor source code.

    repr() produces ``pycopia.SMI.Objects.IndexObjects([n1, n2], implied)``
    with the names written verbatim (unquoted) and *implied* as a bool.
    """
    def __init__(self, init=None, implied=False):
        if not init:
            init = []
        list.__init__(self, init)
        self.implied = bool(implied)

    def __repr__(self):
        names = ", ".join(self)
        template = "pycopia.SMI.Objects.IndexObjects([%s], %r)"
        return template % (names, self.implied)

class ListGenerator(list):
    """List of source-text items whose repr() is a list literal of them.

    Items are written verbatim (unquoted), so they must already be source
    expressions such as class names.
    """
    def __init__(self, init=None):
        list.__init__(self, init or [])

    def __repr__(self):
        return "[" + ", ".join(self) + "]"


class ObjectSourceGenerator(object):
    """
Usage: ObjectSourceGenerator(fileobject, oidfileobject, smimodule)

Parameters:
    fileobject    = A file-type object that receives the generated module.
    oidfileobject = A file-type object that receives the generated OIDMAP
                    module for the same SMI module.
    smimodule     = An SMI module object (not a name).

The constructor writes both file headers; call genAll() to emit the rest.
The caller is responsible for closing both file objects.

Generated class names, and every reference to them, go through
convert_name() so dashes in MIB identifiers become underscores
consistently.
    """
    def __init__(self, fo, oidfo, smimodule):
        self.smimodule = smimodule
        self.fo = fo
        self.oidfo = oidfo
        self.pymodname = convert_name(smimodule.name)
        # SMI module name -> list of symbol names imported from it
        self.imports = {}
        self.fo.write("""# python
# This file is generated by a program (mib2py). Any edits will be lost.

from pycopia.aid import Enum
import pycopia.SMI.Basetypes
Range = pycopia.SMI.Basetypes.Range
Ranges = pycopia.SMI.Basetypes.Ranges

from pycopia.SMI.Objects import %s

""" % (", ".join(IMPORTED_OBJECTS),))

        self.oidfo.write("""# python
# This file is generated by a program (mib2py). 

import %s

OIDMAP = {
""" % (self.pymodname))

    def finalize(self):
        """Close the OIDMAP literal and append the module trailer.

        Writes module-specific extras (handle_specials) and code that
        registers the generated module via SMI.update_oidmap().
        """
        self.oidfo.write("}\n")
        handle_specials(self.fo, self.smimodule)
        self.fo.write("""
# Add to master OIDMAP.
from pycopia import SMI
SMI.update_oidmap(__name__)
""")

    def add_comment(self, text):
        """Write *text* as a single comment line to the generated module."""
        self.fo.write("# %s\n" % text)

    def genImports(self):
        """Write a ``from <module> import <names>`` line per imported module."""
        self.fo.write("# imports \n")
        for node in self.smimodule.get_imports():
            self.imports.setdefault(node.module, []).append(node.name)
        # sorted so the generated source is reproducible between runs
        for modname in sorted(self.imports.keys()):
            impnames = [convert_name(s) for s in self.imports[modname]]
            self.fo.write("from %s import %s\n" % (convert_name(modname), ", ".join(impnames)))
        self.fo.write("\n")

    def genModule(self):
        """Write the ModuleObject class describing the SMI module itself."""
        self.fo.write(genClass(self.smimodule, Objects.ModuleObject))

    def genTypes(self):
        """Write one class per type defined in the module.

        Types whose (converted) name exists in Basetypes are aliased to it.
        Other types get a class derived from the Basetypes class named by
        their snmptype; types without an snmptype produce only a blank line.
        """
        self.fo.write("# types \n")
        for smi_type in self.smimodule.get_types():
            name = convert_name(smi_type.name)
            if hasattr(Basetypes, name):
                self.fo.write("%s = pycopia.SMI.Basetypes.%s\n" % (name, name))
            else:
                self.fo.write("\n")
                if smi_type.snmptype:
                    baseclass = getattr(Basetypes, smi_type.snmptype)
                    self.fo.write(genClass(smi_type, baseclass))

    def genNodes(self):
        """Write a NodeObject class and OIDMAP entry for each named node."""
        self.fo.write("# nodes\n")
        for node in self.smimodule.get_nodes(SMI.SMI_NODEKIND_NODE):
            if node.name:
                initdict = {}
                initdict["name"] = repr(node.name)
                initdict["OID"] = repr(Basetypes.ObjectIdentifier(node.OID))
                self.fo.write(genClass(node, Objects.NodeObject, initdict))
                self._genOIDItem(node.OID, node.name)
        self.fo.write("\n")

    def genScalars(self):
        """Write a ScalarObject class and OIDMAP entry for each exposed scalar."""
        self.fo.write("# scalars \n")
        for scalar in self.smimodule.get_scalars():
            if not self._is_exposed(scalar):
                continue # do not expose optional or obsolete objects
            initdict = self._object_initdict(scalar)
            self.fo.write(genClass(scalar, Objects.ScalarObject, initdict))
            self.fo.write("\n")
            self._genOIDItem(scalar.OID, scalar.name)

    def genColumns(self):
        """Write a ColumnObject class and OIDMAP entry for every column."""
        self.fo.write("# columns\n")
        for col in self.smimodule.get_columns():
            initdict = self._object_initdict(col)
            self.fo.write(genClass(col, Objects.ColumnObject, initdict))
            self.fo.write("\n")
            self._genOIDItem(col.OID, col.name)

    def genRows(self):
        """Write a RowObject class for each exposed row (table entry).

        The generated ``columns`` attribute maps each column's MIB name to
        its generated class.  ``index`` is built by _genIndexObjects, which
        may also write import lines for AUGMENTS targets.
        """
        self.fo.write("# rows \n")
        for row in self.smimodule.get_rows():
            if not self._is_exposed(row):
                continue
            initdict = {}
            colitems = ["%r: %s" % (colname, convert_name(colname))
                        for colname in self._get_colnames(row)]
            initdict["columns"] = "{%s}" % ", ".join(colitems)
            initdict["index"] = self._genIndexObjects(row)
            rowstatus = row.rowstatus
            if rowstatus:
                initdict["rowstatus"] = convert_name(rowstatus.name)
            initdict["OID"] = repr(Basetypes.ObjectIdentifier(row.OID))
            self.fo.write(genClass(row, Objects.RowObject, initdict))
            self.fo.write("\n")

    def genMacros(self):
        """Write a MacroObject class for each macro in the module."""
        self.fo.write("# macros\n")
        for node in self.smimodule.get_macros():
            self.fo.write(genClass(node, Objects.MacroObject))
            self.fo.write("\n")

    def genNotifications(self):
        """Write a NotificationObject class and OIDMAP entry per notification."""
        self.fo.write("# notifications (traps) \n")
        for notif in self.smimodule.get_notifications():
            initdict = {"OID": repr(Basetypes.ObjectIdentifier(notif.OID))}
            self.fo.write(genClass(notif, Objects.NotificationObject, initdict))
            self._genOIDItem(notif.OID, notif.name)

    def genGroups(self):
        """Write a GroupObject class and OIDMAP entry for each exposed group.

        The ``group`` attribute lists the member classes by generated name.
        """
        self.fo.write("# groups \n")
        for group in self.smimodule.get_groups():
            if not self._is_exposed(group):
                continue
            initdict = {}
            initdict["OID"] = repr(Basetypes.ObjectIdentifier(group.OID))
            members = [convert_name(el.get_node().name)
                       for el in group.get_elements()]
            initdict["group"] = "[%s]" % ", ".join(members)
            self.fo.write(genClass(group, Objects.GroupObject, initdict))
            self._genOIDItem(group.OID, group.name)

    def genCompliances(self):
        """Write a class per exposed compliance statement (not used by genAll)."""
        self.fo.write("# compliances \n")
        for comp in self.smimodule.get_compliances():
            if not self._is_exposed(comp):
                continue
            initdict = {}
            mandlist = ListGenerator()
            for el in comp.get_elements():
                mandlist.append(convert_name(el.get_node().name))
            initdict["mandatory_group"] = mandlist
            refs = ListGenerator()
            for ref in comp.get_refinements(): # XXX
                if ref.syntax:
                    refs.append(self._getSyntax(ref)) # XXX
            if refs:
                initdict["refinements"] = repr(refs)
            self.fo.write(genClass(comp, Objects.Compliance, initdict))
            self.fo.write("\n")

    def genCapabilities(self):
        """Write a class per exposed capabilities statement (attributes TBD)."""
        self.fo.write("# capabilities \n")
        for cap in self.smimodule.get_capabilities():
            if not self._is_exposed(cap):
                continue
            initdict = {}
            # XXX
            self.fo.write(genClass(cap, Objects.Capability, initdict))
            self.fo.write("\n")

    # utility methods
    def _is_exposed(self, node):
        """Return True if *node* is current, deprecated or mandatory.

        Optional and obsolete objects are not generated.
        """
        return node.status in (SMI.SMI_STATUS_CURRENT,
                               SMI.SMI_STATUS_DEPRECATED,
                               SMI.SMI_STATUS_MANDATORY)

    def _object_initdict(self, node):
        """Return the syntax/enumeration/OID attributes for a scalar or column."""
        initdict = {}
        initdict["syntaxobject"] = so = self._getSyntax(node)
        if "Enumeration" in so:
            initdict["enumerations"] = node.syntax.enumerations
        initdict["OID"] = repr(Basetypes.ObjectIdentifier(node.OID))
        return initdict

    def _get_colnames(self, row):
        """Return the MIB names of the column children of *row*."""
        return [c.name for c in row.get_children()
                if c.nodekind == SMI.SMI_NODEKIND_COLUMN]

    def _genOIDItem(self, OID, classname):
        """Write one ``'oid': module.Class,`` entry to the OIDMAP file."""
        self.oidfo.write('%r: %s.%s,\n' % (str(OID), self.pymodname, convert_name(classname)))

    def _genIndexObjects(self, smirow):
        """Return source text for *smirow*'s IndexObjects, or None if unindexed.

        For an AUGMENTS row, also writes ``from <module> import <row>`` for
        each augmented row defined in another module, so the name resolves
        in the generated source.  No import is written for rows of this
        module (that would be a self-import).
        """
        index = smirow.get_index()
        if index is None: # old, old v1 MIBS with no index
            return None
        gen = IndexGenerator(implied=index.implied)
        for node in index:
            gen.append(convert_name(node.name))
        if smirow.indexkind == SMI.SMI_INDEX_AUGMENT:
            for node in index:
                modname = node.get_module().name
                if modname != self.smimodule.name:
                    self.fo.write("from %s import %s\n" % (convert_name(modname), convert_name(node.name)))
        return repr(gen)

    def _getSyntax(self, node):
        """Return source text naming the syntax class of *node*.

        Unnamed syntaxes are resolved through their parent type.  Names
        found in Objects or Basetypes are returned fully qualified;
        anything else is assumed to be a type defined in this module.
        Returns "UNKNOWN" (and prints a warning) if the node has no syntax.
        """
        syntax = node.syntax
        if syntax is None:
            print("***** unable to get SYNTAX for node %s" % (node.name))
            return "UNKNOWN"
        if not syntax.name:
            syntax = syntax.get_parent()
        syntaxname = syntax.name
        if not syntaxname:
            syntaxname = syntax.snmptype
        if hasattr(Objects, syntaxname):
            cl = getattr(Objects, syntaxname)
            return "%s.%s" % (cl.__module__, cl.__name__)
        elif hasattr(Basetypes, syntaxname):
            cl = getattr(Basetypes, syntaxname)
            return "%s.%s" % (cl.__module__, cl.__name__)
        # else must be a locally defined type.
        return syntaxname

    def genAll(self):
        """Generate the complete module and OIDMAP files, then finalize them."""
        self.genImports()
        self.genModule()
        self.genNodes()
        self.genMacros()
        self.genTypes()
        self.genScalars()
        self.genColumns()
        self.genRows()
        self.genNotifications()
        self.genGroups()
        #self.genCompliances()
        self.genCapabilities()
        self.finalize()


# some modules require special handling. Crude, hopefully temporary, hack

def handle_specials(fo, smimodule):
    """Write module-specific extra definitions for *smimodule* to *fo*.

    A "# special additions" header is always written.  SNMPv2-SMI and
    SNMPv2-TC then get extra aliases; other modules get nothing more.
    """
    fo.write("\n# special additions\n")
    modname = smimodule.name
    if modname == 'SNMPv2-SMI':
        special = _handle_smi
    elif modname == 'SNMPv2-TC':
        special = _handle_tc
    else:
        special = _handle_default
    special(fo, smimodule)

def _handle_smi(fo, mod):
    fo.write("\n")
    for name in ("ObjectSyntax", "SimpleSyntax", "ApplicationSyntax"):
        fo.write("%s = pycopia.SMI.Basetypes.%s\n" % (name, name))

def _handle_tc(fo, mod):
    fo.write("\n")
    for name in ("Bits", "BITS"):
        fo.write("%s = pycopia.SMI.Basetypes.%s\n" % (name, name))

def _handle_default(fo, mod):
    pass


def _remove_quietly(path):
    """Best-effort removal of a partially written output file."""
    try:
        os.remove(path)
    except OSError:
        pass

def _compile_module(smimodule):
    """Generate <name>.py and <name>_OID.py under USERMIBPATH for *smimodule*.

    Modules without a name are ignored.  If <name>.py already exists the
    module is skipped.  Both output files are always closed.  If generation
    raises, the files written by this call are removed (so a later run
    regenerates them instead of skipping a broken file) and the exception
    propagates.  Byte-compilation problems are reported and not raised.
    """
    import sys  # local: only needed for portable exception access
    if not smimodule.name:
        return # unnamed from where?
    basename = convert_name(smimodule.name)
    fname = os.path.join(USERMIBPATH, basename + ".py")
    oidfname = os.path.join(USERMIBPATH, basename + "_OID.py")
    if os.path.exists(fname):
        print("    +++ file %r exists, skipping." % (fname, ))
        return
    print("Compiling module %s" % (smimodule.name,))
    ok = False
    fd = open(fname, "w")
    try:
        oidfd = open(oidfname, "w")
        try:
            generator = ObjectSourceGenerator(fd, oidfd, smimodule)
            generator.genAll()
            ok = True
        finally:
            oidfd.close()
            if not ok:
                _remove_quietly(oidfname)
    finally:
        fd.close()
        if not ok:
            _remove_quietly(fname)
    try:
        py_compile.compile(fname)
    except Exception:
        # sys.exc_info() keeps this valid on both old and new Pythons
        print("*** %s" % (sys.exc_info()[1],))

def compile_module(modname, preload=None, all=False):
    """Compile SMI module *modname* to Python source under USERMIBPATH.

    modname -- name of the SMI module to compile.
    preload -- optional iterable of SMI module names to load first.
    all     -- if true, also compile every module that *modname* imports,
               directly or indirectly.

    Prints a message and returns None if *modname* cannot be loaded.
    Dependencies that cannot be loaded are reported and skipped instead of
    crashing the whole run.
    """
    if preload:
        for pm in preload:
            SMI.load_module(pm)
    smimodule = SMI.get_module(modname)
    if not smimodule:
        print("Could not load module %s" % (modname,))
        return
    if all:
        for dep in _get_dependents(smimodule):
            depmodule = SMI.get_module(dep)
            if not depmodule:
                print("Could not load module %s" % (dep,))
                continue
            _compile_module(depmodule)
    _compile_module(smimodule)

def _get_dependents(module, hash=None):
    h = hash or {}
    for imp in module.get_imports():
        h[imp.module] = True
        _get_dependents(SMI.get_module(imp.module), h)
    return h.keys()

def compile_everything(all=False):
    """Compile every MIB file found on the SMI search path.

    all -- passed to compile_module(): also compile each module's imports.

    The path is split on ":".  Empty entries and directories that do not
    exist are reported and skipped instead of aborting the run.  Files are
    processed in sorted order.  SMI state is cleared and re-initialized
    after each file so memory does not grow across the run.  SMI errors are
    reported per file; the final count includes only modules processed
    without an SMI error.
    """
    import sys  # local: only needed for portable exception access
    count = 0
    paths = (SMI.get_path() or "").split(":")
    for dirname in paths:
        if not dirname or not os.path.isdir(dirname):
            print("Skipping missing directory %r" % (dirname,))
            continue
        print("Looking in %s" % (dirname,))
        for modname in sorted(os.listdir(dirname)):
            modpath = os.path.join(dirname, modname)
            if not os.path.isfile(modpath):
                continue
            print("Found module %s compiling..." % (modname,))
            try:
                compile_module(modname, None, all)
            except SMI.SmiError:
                print("***[ %s ]***" % (sys.exc_info()[1],))
            else:
                count += 1
            SMI.clear() # clear out memory
            SMI.init()
    print("Found and compiled %d MIBS." % (count, ))


if __name__ == "__main__":
    # Command-line use: compile every MIB on the SMI path, dependencies included.
    compile_everything(all=True)


